{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6d742657-8e6a-439d-a65a-3c76e73c8810",
   "metadata": {},
   "source": [
    "# Logistic Regression on PPU"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff9c4ba5-a4a8-46ec-96e3-611ef5e71dc5",
   "metadata": {},
   "source": [
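    "PPU (Privacy-Preserving Processing Unit) is a provable, measurable secure computing device that performs privacy-preserving computation with MPC protocols; it currently supports SPDZ-2k and ABY3. PPU provides basic tensor computation and uses [XLA HLO](https://www.tensorflow.org/xla) as its IR, so you can program it with any frontend that supports XLA, such as JAX, TensorFlow, or PyTorch. For more details about PPU, see the [official PPU documentation](https://ppu.antfin-inc.com/index.html).\n",
    "\n",
    "Currently, PPU's support for JAX is the most mature, while support for other frameworks has not been fully tested, so this tutorial uses JAX.\n",
    "\n",
    "This tutorial gives you a first look at building machine learning models on PPU.\n",
    "\n",
    "We use the [Breast Cancer](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(diagnostic)) dataset, a simple binary classification dataset with 30 features for predicting whether a breast tumor is benign or malignant. Our scenario is vertically partitioned: each of the two parties contributes 15 features, and together they train a single logistic regression model.\n",
    "\n",
    "As a first step, we ignore privacy-preserving semantics and train a logistic regression model in plaintext directly with JAX."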
   ]
  },
  {
   "cell_type": "markdown",
   "id": "629aad7f",
   "metadata": {},
   "source": [
    "## Plaintext Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9774618-798e-421f-9be0-4fb954d7c710",
   "metadata": {},
   "source": [
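    "### Generating the Data\n",
    "\n",
    "To simulate a vertically partitioned scenario, `load_dataset` splits the features in half: the party that calls it with `return_label=True` gets the last 15 features plus the labels, and the other party gets only the first 15 features."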
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "364a380e-9cea-42e3-8ab5-06635df97478",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
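    "from typing import Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "\n",
    "\n",
    "def load_dataset(return_label=False) -> Tuple[np.ndarray, Optional[np.ndarray]]:\n",
    "    # Load all 30 features and the binary labels of the Breast Cancer dataset.\n",
    "    features, label = load_breast_cancer(return_X_y=True)\n",
    "\n",
    "    if return_label:\n",
    "        # The labeled party holds the last 15 features and the labels.\n",
    "        return features[:, 15:], label\n",
    "    else:\n",
    "        # The other party holds only the first 15 features.\n",
    "        return features[:, :15], None"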
   ]
  },
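  {
   "cell_type": "markdown",
   "id": "vertical-split-sketch-md",
   "metadata": {},
   "source": [
    "To make the vertical split concrete, here is a small sanity-check sketch. It uses a hypothetical random matrix (569 samples x 30 features, the shape of the Breast Cancer feature matrix) instead of the real data, so it does not depend on sklearn. It checks that concatenating the two halves along the feature axis restores the original matrix, which is what `train_step` relies on later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vertical-split-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for the feature matrix: 569 samples x 30 features.\n",
    "demo_rng = np.random.default_rng(0)\n",
    "demo_full = demo_rng.normal(size=(569, 30))\n",
    "\n",
    "# Vertical split: both halves keep every sample (row) but only 15 of the columns.\n",
    "demo_a, demo_b = demo_full[:, :15], demo_full[:, 15:]\n",
    "print(demo_a.shape, demo_b.shape)\n",
    "\n",
    "# Concatenating along the feature axis (axis=1) recovers the full matrix.\n",
    "assert np.array_equal(np.concatenate([demo_a, demo_b], axis=1), demo_full)"
   ]
  },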
  {
   "cell_type": "markdown",
   "id": "8b3893fc-5f6f-4508-a93b-1fee64b6806d",
   "metadata": {},
   "source": [
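    "Next, we preprocess the data, which helps improve model quality. Here we use sklearn to standardize each feature (subtract its mean and divide by its standard deviation), so that every feature has zero mean and unit variance."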
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "002e2e86-bacf-4ff0-b3d1-891e62e4ae2c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "def transform(data):\n",
    "    scaler = StandardScaler()\n",
    "    return scaler.fit_transform(data)"
   ]
  },
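  {
   "cell_type": "markdown",
   "id": "standardize-sketch-md",
   "metadata": {},
   "source": [
    "For reference, `StandardScaler.fit_transform` computes `(x - mean) / std` column by column, using the population standard deviation (`ddof=0`). The hypothetical `standardize` helper below is a NumPy sketch of that formula; unlike `StandardScaler`, which only centers constant columns, it does not guard against zero variance. Because the statistics are per column, each party can standardize its own 15 features without seeing the other party's data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "standardize-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def standardize(data):\n",
    "    # (x - mean) / std per column, with the population std (ddof=0) like StandardScaler.\n",
    "    # Sketch only: a constant column would cause a division by zero here.\n",
    "    return (data - data.mean(axis=0)) / data.std(axis=0)\n",
    "\n",
    "\n",
    "demo_data = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])\n",
    "demo_z = standardize(demo_data)\n",
    "print(demo_z)\n",
    "# Each standardized column has mean 0 and standard deviation 1.\n",
    "print(demo_z.mean(axis=0), demo_z.std(axis=0))"
   ]
  },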
  {
   "cell_type": "markdown",
   "id": "fb400b1f-f516-4348-8812-e05a2e8f5e17",
   "metadata": {
    "tags": []
   },
   "source": [
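    "### Model Training\n",
    "\n",
    "#### Model Definition\n",
    "First, we define the logistic regression model and the loss function with JAX. This code is generic: it runs on CPU/GPU as well as on PPU, so you can easily check PPU's correctness against the CPU results."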
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ffb9d32-4150-41cb-a6e9-65cbccf6568a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + jnp.exp(-x))\n",
    "\n",
    "\n",
    "# Outputs probability of a label being true.\n",
    "def predict(W, b, inputs):\n",
    "    return sigmoid(jnp.dot(inputs, W) + b)\n",
    "\n",
    "\n",
    "# Training loss is the negative log-likelihood of the training examples.\n",
    "def loss(W, b, inputs, targets):\n",
    "    preds = predict(W, b, inputs)\n",
    "    label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
    "    return -jnp.mean(jnp.log(label_probs))"
   ]
  },
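  {
   "cell_type": "markdown",
   "id": "gradient-check-sketch-md",
   "metadata": {},
   "source": [
    "The gradients that `value_and_grad` returns for this loss have a well-known closed form: with `p = sigmoid(X @ W + b)` and `n` samples, `dL/dW = X.T @ (p - y) / n` and `dL/db = mean(p - y)`. The NumPy sketch below (the `*_np` helpers are hypothetical, not part of the tutorial) checks this formula against central finite differences on a small random problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gradient-check-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid_np(z):\n",
    "    return 1 / (1 + np.exp(-z))\n",
    "\n",
    "\n",
    "def loss_np(W, b, X, y):\n",
    "    # Same negative log-likelihood as `loss` above, written in NumPy.\n",
    "    p = sigmoid_np(X @ W + b)\n",
    "    return -np.mean(np.log(p * y + (1 - p) * (1 - y)))\n",
    "\n",
    "\n",
    "def grad_np(W, b, X, y):\n",
    "    # Closed-form gradients: dL/dW = X.T @ (p - y) / n, dL/db = mean(p - y).\n",
    "    r = sigmoid_np(X @ W + b) - y\n",
    "    return X.T @ r / len(y), np.mean(r)\n",
    "\n",
    "\n",
    "# Small hypothetical problem: 20 samples, 3 features, 0/1 labels.\n",
    "gc_rng = np.random.default_rng(0)\n",
    "gc_X = gc_rng.normal(size=(20, 3))\n",
    "gc_y = (gc_rng.random(20) > 0.5).astype(float)\n",
    "gc_W, gc_b = gc_rng.normal(size=3), 0.1\n",
    "\n",
    "gW, gb = grad_np(gc_W, gc_b, gc_X, gc_y)\n",
    "\n",
    "# Central finite differences as an independent reference.\n",
    "eps = 1e-6\n",
    "num_gW = np.array([\n",
    "    (loss_np(gc_W + eps * e, gc_b, gc_X, gc_y) - loss_np(gc_W - eps * e, gc_b, gc_X, gc_y)) / (2 * eps)\n",
    "    for e in np.eye(3)\n",
    "])\n",
    "num_gb = (loss_np(gc_W, gc_b + eps, gc_X, gc_y) - loss_np(gc_W, gc_b - eps, gc_X, gc_y)) / (2 * eps)\n",
    "print(np.max(np.abs(gW - num_gW)), abs(gb - num_gb))"
   ]
  },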
  {
   "cell_type": "markdown",
   "id": "7eaaf2b4-6491-471a-85fb-d63236e4a564",
   "metadata": {},
   "source": [
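    "Next, we define the optimizer; for simplicity, we use a plain SGD update. Note that `x1` and `x2` hold alice's and bob's features respectively. Because the data is vertically partitioned, we concatenate them along the feature dimension (`axis=1`), so that each row contains all 30 features of one sample."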
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "71e01788-804f-4378-a268-b84c3940320a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import value_and_grad\n",
    "\n",
    "def train_step(W, b, x1, x2, y, learning_rate=1e-2) -> (np.ndarray, np.ndarray, np.ndarray):\n",
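    "    # Concatenate both parties' features along the feature axis: (n, 15) + (n, 15) -> (n, 30).\n",
    "    x = jnp.concatenate([x1, x2], axis=1)\n",
    "    # Compute the loss and its gradients w.r.t. W (argnum 0) and b (argnum 1).\n",
    "    loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b, x, y)\n",
    "    W -= learning_rate * Wb_grad[0]\n",
    "    b -= learning_rate * Wb_grad[1]\n",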
    "    return loss_value, W, b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eac6baac-280c-44b4-b8a0-9a6eb3a76908",
   "metadata": {
    "tags": []
   },
   "source": [
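    "#### Training Loop\n",
    "\n",
    "Next, we combine the model definition, loss function, and optimizer into a training loop. In each iteration, we feed all samples to the model, update the parameters, and record the loss."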
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f15b0c71-5761-41aa-b70d-9ba242f8a795",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):\n",
    "    losses = jnp.array([])\n",
    "    for _ in range(epochs):\n",
    "        l, W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)\n",
    "        losses = jnp.append(losses, l)\n",
    "    return losses, W, b"
   ]
  },
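  {
   "cell_type": "markdown",
   "id": "initial-loss-sketch-md",
   "metadata": {},
   "source": [
    "A quick way to sanity-check the training loop, on CPU or on PPU: since `W` and `b` start at zero, every initial prediction is `sigmoid(0) = 0.5`, so the first recorded loss must be `-log(0.5) = log(2)` (about 0.693) whatever the data, and the first gradient with respect to `b` is `0.5 - mean(y)`. The NumPy sketch below verifies both facts on hypothetical random data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial-loss-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical data; the results below do not depend on its values.\n",
    "init_rng = np.random.default_rng(1)\n",
    "init_X = init_rng.normal(size=(100, 30))\n",
    "init_y = (init_rng.random(100) > 0.4).astype(float)\n",
    "\n",
    "# Same zero initialization as in the tutorial.\n",
    "init_W, init_b = np.zeros(30), 0.0\n",
    "init_p = 1 / (1 + np.exp(-(init_X @ init_W + init_b)))  # every entry is 0.5\n",
    "\n",
    "# Loss as defined in `loss`: -mean(log(p * y + (1 - p) * (1 - y))).\n",
    "init_loss = -np.mean(np.log(init_p * init_y + (1 - init_p) * (1 - init_y)))\n",
    "# Gradient w.r.t. b at the first step: mean(p - y) = 0.5 - mean(y).\n",
    "init_grad_b = np.mean(init_p - init_y)\n",
    "print(init_loss, np.log(2), init_grad_b, 0.5 - init_y.mean())"
   ]
  },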
  {
   "cell_type": "markdown",
   "id": "74655110-6df8-4da9-8e33-b9b87a0a027d",
   "metadata": {},
   "source": [
    "#### Visualizing Metrics\n",
    "\n",
    "We can watch how the loss curve on the training set evolves in order to tune hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61b1837c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_losses(losses):\n",
    "    plt.plot(np.arange(len(losses)), losses)\n",
    "    plt.xlabel('epoch')\n",
    "    plt.ylabel('loss')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dd8abd5",
   "metadata": {},
   "source": [
    "### Validating the Model\n",
    "\n",
    "Let's check the accuracy and AUC on the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "132fcee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "def validate_model(W, b, x, y):\n",
    "    # Predicted probabilities on the (already concatenated) features x.\n",
    "    y_pred = predict(W, b, x)\n",
    "    auc = roc_auc_score(y, y_pred)\n",
    "    acc = jnp.mean((y_pred > 0.5) == y)\n",
    "    print(f'auc={auc}, acc={acc}')"
   ]
  },
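  {
   "cell_type": "markdown",
   "id": "metrics-sketch-md",
   "metadata": {},
   "source": [
    "To see what the two metrics measure, the sketch below computes them by hand on four hypothetical predictions. Accuracy thresholds the probabilities at 0.5, exactly like `validate_model`. The hypothetical `pairwise_auc` helper uses the rank interpretation of AUC: the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counting half, which is the quantity `roc_auc_score` reports for binary labels. It is O(n_pos * n_neg) and only meant for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "metrics-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def pairwise_auc(y_true, y_score):\n",
    "    # Compare every positive score with every negative score.\n",
    "    pos = y_score[y_true == 1]\n",
    "    neg = y_score[y_true == 0]\n",
    "    diff = pos[:, None] - neg[None, :]\n",
    "    # Count correctly ordered pairs; ties count as half a pair.\n",
    "    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "y_demo = np.array([1, 0, 0, 1])\n",
    "p_demo = np.array([0.9, 0.2, 0.6, 0.4])\n",
    "acc_demo = np.mean((p_demo > 0.5) == y_demo)\n",
    "auc_demo = pairwise_auc(y_demo, p_demo)\n",
    "print(acc_demo, auc_demo)  # expected: 0.5 0.75"
   ]
  },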
  {
   "cell_type": "markdown",
   "id": "581fac73",
   "metadata": {},
   "source": [
    "### Putting It All Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ad002d29",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc=0.9838539189260611, acc=0.9349736571311951\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEKCAYAAAAIO8L1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAApnElEQVR4nO3dd3gVZfr/8fedRpMmBEV6SaQjEhAITUFERbCBoKKoKzZQrKvr+t39qdvsoigg9g4oiK6KqHQIEBCkSZUSRAhIbyFw//44x93IBgyQk0n5vK7rXGaemTnnPkdyPplnZp7H3B0REZEjRQVdgIiI5E8KCBERyZYCQkREsqWAEBGRbCkgREQkWwoIERHJVkQDwsy6mtkyM1tpZg9ms/5ZM5sffiw3s+1Z1l1vZivCj+sjWaeIiPwvi9R9EGYWDSwHzgfSgDlAH3dfcpTtBwLN3P1GMzsVSAWSAAfmAs3dfVtEihURkf8RySOIlsBKd1/t7hnAB0CPY2zfB3g//PMFwAR3/yUcChOArhGsVUREjhATweeuAqzPspwGnJPdhmZWA6gFfHuMfasc68UqVqzoNWvWPNFaRUSKpLlz525x9/js1kUyII5Hb2C0ux86np3MrD/QH6B69eqkpqZGojYRkULLzNYebV0ku5g2ANWyLFcNt2WnN//tXsrxvu4+3N2T3D0pPj7bABQRkRMUyYCYAySYWS0ziyMUAuOO3MjM6gHlgZlZmscDXcysvJmVB7qE20REJI9ErIvJ3TPNbAChL/Zo4DV3X2xmjwKp7v5rWPQGPvAsl1O5+y9m9hihkAF41N1/iVStIiLyvyJ2mWteS0pKcp2DEBE5PmY2192TslunO6lFRCRbCggREcmWAkJERLJV5APi8GHn758vZf0ve4MuRUQkXynyAbFm6x4+mL2OHkOmM2v11qDLERHJN4p8QNSOP4WxdyRTrmQs14yYxfuz1wVdkohIvlDkAwJCITHm9mTa1K3IQx8v5K/jFpN56HDQZYmIBEoBEVa2RCyvXZ/ETW1r8caMNdzwxhx27D0YdFkiIoFRQGQREx3FI90a8MQVTUhZvZVLX5rOqvTdQZclIhIIBUQ2erWoxns3t2LnvoNcOmQ6k5enB12SiEieU0AcRYuap/LJgGSqli/JDa/P5tVpP1JYhiUREckJBcQxVC1fktG3tqZLg9N57LMl/PGj7zmQeVxTVoiIFFgKiN9RqlgML11zNnd2SmBkahrXjpjFlt0Hgi5LRCTiFBA5EBVl3HN+Ii9e3YyFG3bQ48XpLPlpZ9BliYhElALiOHRrcgajbmnDocPOFS/P4MtFPwddkohIxCggjlPjqmUZNyCZM08vza3vzGXwNyt08lpECiUFxAmoVKY4H/RvxeXNqvDMhOUMfP879mXo5LWIFC4Rm3K0sCseG83TvZqSeHpp/vXlD6zdupfh1zWnctkSQZcmIpIrdARxEsyMWzvUYcR1Sfy4ZQ/dX5zOd+u2BV2WiEiuiGhAmFlXM1tmZivN7MGjbNPLzJaY2WIzey9L+xPhtqVmNtjMLJK1noxO9U/j49vbUCI2mquGpzDmu7SgSxIROWkRCwgziwaGABcCDYA+ZtbgiG0SgIeAZHdvCAwKt7cBkoEmQCOgBdAhUrXmhsTTSvPJHcmcXb0cd3+4gH9+8QOHDuvktYgUXJE8gmgJrHT31e6eAXwA9Dhim5uBIe6+DcDdN4fbHSgOxAHFgFhgUwRrzRXlS8Xx9k3ncG2r6gydvIr+b6Wya79GhBWRgimSAVEFWJ9lOS3cllUikGhm080sxcy6Arj7TGAisDH8GO/uSyNYa66JjY7i8Usb81iPhkxans4VL89g3VZNZyoiBU/QJ6ljgASgI9AHeMXMyplZXaA+UJVQqJxnZu2O3NnM+ptZqpmlpqfnrxFX+7auyds3tmTTzgN0HzKNGau2BF2SiMhxiWRAbACqZVmuGm7LKg0Y5+4H
3f1HYDmhwLgMSHH33e6+G/gCaH3kC7j7cHdPcvek+Pj4iLyJk9GmbkU+uSOZiqcU47pXZ/NOytqgSxIRybFIBsQcIMHMaplZHNAbGHfENmMJHT1gZhUJdTmtBtYBHcwsxsxiCZ2gLhBdTEeqWbEUH9/ehnYJFfnz2EU8MnYRBzWdqYgUABELCHfPBAYA4wl9uY9098Vm9qiZdQ9vNh7YamZLCJ1zuN/dtwKjgVXAQmABsMDdP41UrZFWpngsI65vwS3ta/N2ylquf2022/dmBF2WiMgxWWEZRygpKclTU1ODLuN3fTQ3jYc+XkjlcsV59fok6lYqHXRJIlKEmdlcd0/Kbl3QJ6mLnCuaV+X9/q3Yc+AQlw2ZwcQfNv/+TiIiAVBABKB5jfKMG5BM9QolufHNOQyfskojwopIvqOACMgZ5Uow6tbWXNSoMn///AfuG6XpTEUkf1FABKhkXAwvXt2Muzsn8tG8NPoMT2Hzzv1BlyUiAiggAmdm3NU5gZevOZulG3dx0eBppKzeGnRZIiIKiPziwsaV+WRAMmVKxHDNiFkMm6zzEiISLAVEPpJ4WmnGDWjLBQ1P4x9f/MCt78xlpwb7E5GAKCDymVOKxTDk6rN5pFsDvlm6me4vTGPpxp1BlyUiRZACIh8yM25qW4v3+7dib8YhLntpOh/P0yREIpK3FBD5WIuap/LZnW05q1o57hm5gIfHLNSlsCKSZxQQ+Vyl0sV556ZzuLVDHd6dtY6eQ2eStk3zS4hI5CkgCoCY6CgevLAew/o258f0PXR7YRqTlmmIDhGJLAVEAXJBw9MZN7Atp5cpzg1vzOG5r5dzWPNei0iEKCAKmFoVSzHm9mQua1aF575ewQ1vzGHbHg0dLiK5TwFRAJWIi+bpnk35+2WNmblqK91emMaC9duDLktEChkFRAFlZlx9TnVG3xaaibXn0Jm8O2ut7r4WkVyjgCjgmlQtx2cD29K6TgUeHrOIe0ctYF+GLoUVkZOngCgEypeK4/V+Lbi7cyJjvtvAZS9N58cte4IuS0QKOAVEIREVFRoV9o0bWvLzzv10f2Ea4xf/HHRZIlKAKSAKmQ6J8Xw2sC2140txy9tz+ccXS8k8dDjoskSkAIpoQJhZVzNbZmYrzezBo2zTy8yWmNliM3svS3t1M/vKzJaG19eMZK2FSdXyJRl5a2uubVWdYZNXc82IWWzepYmIROT4RCwgzCwaGAJcCDQA+phZgyO2SQAeApLdvSEwKMvqt4An3b0+0BLQrcPHoVhMNI9f2phnejVlQdp2ug2expw1vwRdlogUIJE8gmgJrHT31e6eAXwA9Dhim5uBIe6+DcDdNwOEgyTG3SeE23e7uwYgOgGXn12VsXckUzIumt7DUxgxdbUuhRWRHIlkQFQB1mdZTgu3ZZUIJJrZdDNLMbOuWdq3m9nHZvadmT0ZPiL5DTPrb2apZpaanp4ekTdRGNQ7vQzjBralc/1KPP7vpdzx3jx2aSIiEfkdQZ+kjgESgI5AH+AVMysXbm8H3Ae0AGoD/Y7c2d2Hu3uSuyfFx8fnUckFU5nisQy9tjl/uqge4xdvoseQ6SzftCvoskQkH4tkQGwAqmVZrhpuyyoNGOfuB939R2A5ocBIA+aHu6cygbHA2RGstUgwM/q3r8O7fziHnfsy6fHidD6Zf+T/EhGRkEgGxBwgwcxqmVkc0BsYd8Q2YwkdPWBmFQl1La0O71vOzH49LDgPWBLBWouUVrUr8PmdbWlUpQx3fTCfv3yyiIxMXQorIr8VsYAI/+U/ABgPLAVGuvtiM3vUzLqHNxsPbDWzJcBE4H533+ruhwh1L31jZgsBA16JVK1FUaUyxXnv5lbc3K4Wb85cS69hM/lp+76gyxKRfMQKyxUtSUlJnpqaGnQZBdLnCzfywOjviYuJYnDvZrRNqBh0SSKSR8xsrrsnZbcu6JPUkg9c1LgynwxIpuIpcfR9bRaDv1nBIU1EJFLkKSAEgDrxpzD2
jmS6Nz2DZyYsp++rs9i0U3dfixRlCgj5j5JxMTx31Vn864rGfLduO12fm8I3SzcFXZaIBEQBIb9hZlzVojqfDmxL5bIluOnNVP46bjH7D2qOCZGiRgEh2apb6RTG3NGGG5Nr8caMNVw6ZDorN+vGOpGiRAEhR1UsJpr/u6QBr/VLYvOuA3R7YRrvz16nsZxEiggFhPyu8+qdxpd3tSOpxqk89PFC7nhvHjv2aiwnkcJOASE5UqlMcd66sSUPXliPrxZv4qLBU0nV8OEihZoCQnIsKsq4tUMdRt/Whphoo9ewmTz/te6ZECmsFBBy3M6qVo7PBralx1lVePbr5fR5JUXDdIgUQgoIOSGli8fy7FVn8UyvpizesIMLn5/Kl4s2Bl2WiOQiBYSclMvPrsq/72xHjQolufWdefxpzEL2ZeieCZHCQAEhJ61mxVKMvrUNt3SozXuz1tH9xWn88PPOoMsSkZOkgJBcERcTxUMX1uetG1uybe9Bur84nbdmrtE9EyIFmAJCclX7xHi+HNSONnUq8H+fLObmt+aybU9G0GWJyAlQQEiuq3hKMV67vgWPdGvA5OWbufD5qcxctTXoskTkOCkgJCKiooyb2tZizO3JlIyL5uoRKTw1fhkHD2lqU5GCQgEhEdWoSlk+HdiWns2r8uLElVw1bCbrf9kbdFkikgMKCIm4UsVieOLKprzQpxkrNu3mouen8umCn4IuS0R+R0QDwsy6mtkyM1tpZg8eZZteZrbEzBab2XtHrCtjZmlm9mIk65S8cUnTM/j8rnbUPe0UBr7/HQ+MXsDejMygyxKRo4hYQJhZNDAEuBBoAPQxswZHbJMAPAQku3tDYNART/MYMCVSNUreq3ZqSUbe0poB59Zl1Nw0ug2exqINO4IuS0SyEckjiJbASndf7e4ZwAdAjyO2uRkY4u7bANx9868rzKw5cBrwVQRrlADERkdx3wVn8u4fzmFPRiaXvzSDV6f9qHsmRPKZSAZEFWB9luW0cFtWiUCimU03sxQz6wpgZlHA08B9x3oBM+tvZqlmlpqenp6LpUteaFOnIl/c1Z72ifE89tkSbnhjDlt2Hwi6LBEJC/okdQyQAHQE+gCvmFk54Hbgc3dPO9bO7j7c3ZPcPSk+Pj7StUoEnFoqjleua85jPRoyY9VWuj43lSnLFfYi+UEkA2IDUC3LctVwW1ZpwDh3P+juPwLLCQVGa2CAma0BngKuM7N/RrBWCZCZ0bd1TcYNSKZ8yViue202//h8KRmZumdCJEiRDIg5QIKZ1TKzOKA3MO6IbcYSOnrAzCoS6nJa7e7XuHt1d69JqJvpLXfP9iooKTzqnV6GcQPacs051Rk2ZTVXvDyDFZt2BV2WSJEVsYBw90xgADAeWAqMdPfFZvaomXUPbzYe2GpmS4CJwP3urjEZirAScdH87bLGDL22ORu27+PiF6YxfMoqzVonEgArLFeOJCUleWpqatBlSC5K33WAh8cs5Kslm2hRszxP9WxKjQqlgi5LpFAxs7nunpTduqBPUoscVXzpYgzr25xnejXlh5930fW5qbydslaXw4rkEQWE5GtmxuVnV+Wru9uTVLM8j4xdxHWvzdYc2CJ5QAEhBULlsiV468aWPH5pI+au3cYFz05h9Nw0HU2IRJACQgoMM+PaVjX44q521K9chvtGLaD/23NJ36Wb60QiQQEhBU6NCqV4v38r/nxxfSYvT6fLs5P5fOHGoMsSKXQUEFIgRUcZf2hXm8/vbEu1U0ty+7vzuPP979i+V9ObiuQWBYQUaHUrleaj29pwz/mJfL5wI12encK3P2wKuiyRQkEBIQVebHQUd3ZKYOwdyZQvGceNb6Tyx9Hfs2v/waBLEynQFBBSaDSqUpZxA5O5rWMdRs1dT9fnpjJj1ZagyxIpsBQQUqgUi4nmj13rMerWNsTFRHH1K7P467jF7Ms4FHRpIgWOAkIKpeY1yvPvO9vSr01N3pixhosGT2Xu2m1BlyVS
oOQoIMzsrvD80GZmr5rZPDPrEuniRE5GybgY/tq9Ie/94RwyMg/Tc+gM/vXlDxzI1NGESE7k9AjiRnffCXQBygN9Ac3PIAVCm7oV+XJQO3o2r8bLk1bR48XpmgdbJAdyGhAW/u9FwNvuvjhLm0i+V7p4LP+6sgmv9Uti654MLh0yncHfrCDzkCYlEjmanAbEXDP7ilBAjDez0oB+s6TAOa/eaXw1qD0XNa7MMxOWc8XLM1i5WZMSiWQnpwFxE/Ag0MLd9wKxwA0Rq0okgsqXimNwn2YMufps1v2yl4sGT2PE1NWalEjkCDkNiNbAMnffbmbXAn8G1IkrBdrFTSoz/u72tE+oyOP/Xkqf4Sms27o36LJE8o2cBsTLwF4zawrcC6wC3opYVSJ5pFLp4rxyXRJP9WzK0o076fr8FN6dpUmJRCDnAZHpod+YHsCL7j4EKB25skTyjplxZfOqjL+7PWdXL8/DY0KTEm3coUmJpGjLaUDsMrOHCF3e+m8ziyJ0HuKYzKyrmS0zs5Vm9uBRtullZkvMbLGZvRduO8vMZobbvjezq3L6hkRO1BnlQpMSPdajIalrttHl2Sl8PE+TEknRZTn5x29mpwNXA3PcfaqZVQc6uvtRu5nMLBpYDpwPpAFzgD7uviTLNgnASOA8d99mZpXcfbOZJQLu7ivM7AxgLlDf3bcf7fWSkpI8NTU1B29Z5Pet2bKH+0YtIHXtNjrXP43HLm1I5bIlgi5LJNeZ2Vx3T8puXY6OINz9Z+BdoKyZdQP2HyscwloCK919tbtnAB8Q6qLK6mZgiLtvC7/O5vB/l7v7ivDPPwGbgfic1CqSG2pWLMWHt7Tm4YvqM21lOuc/M4W3U9ZyWFc6SRGS06E2egGzgZ5AL2CWmV35O7tVAdZnWU4Lt2WVCCSa2XQzSzGzrtm8dksgjtCJ8SPX9TezVDNLTU9Pz8lbEcmx6Cjj5va1GT+oPWdVK8cjYxfRa9hMVmzSfRNSNOT0HMTDhO6BuN7dryN0dPBILrx+DJAAdAT6AK+YWblfV5pZZeBt4AZ3/58b89x9uLsnuXtSfLwOMCQyalQoxds3teSpnk1ZsXk3Fw+exnNfL9eYTlLo5TQgon7t/gnbmoN9NwDVsixXDbdllQaMc/eD7v4joXMWCQBmVgb4N/Cwu6fksE6RiPj1Sqdv7u1A10an89zXK+g2eBpz1/4SdGkiEZPTgPjSzMabWT8z60foi/vz39lnDpBgZrXMLA7oDYw7YpuxhI4eMLOKhLqcVoe3HwO85e6jc1ijSMRVPKUYg/s04/V+LdhzIJMrh87k/z5ZpNnrpFDK6Unq+4HhQJPwY7i7//F39skEBgDjgaXASHdfbGaPmln38Gbjga1mtgSYCNzv7lsJnedoD/Qzs/nhx1nH//ZEIuPcepX46p4OXN+6Jm+nrKXLs1P4eonmwpbCJUeXuRYEusxVgjJv3TYe+mghyzbt4uImlfnLJQ2oVLp40GWJ5MgJX+ZqZrvMbGc2j11mtjMy5YoULGdXL8+nA9ty7/mJTFi8ic5PT2bknPW6wU4KvGMGhLuXdvcy2TxKu3uZvCpSJL+Li4liYKcEPr+rHfVOL8MDH33PNSNmsWbLnqBLEzlhmpNaJBfVrXQKH/Rvxd8ua8TCtB1c8NwUXp60ioOamEgKIAWESC6LijKuOacGX9/bgXPPrMS/vvyBHi9OZ2GaRsiXgkUBIRIhp5UpztC+zRl6bXO27D5AjyHTePyzJezNyAy6NJEcUUCIRFjXRqcz4Z4O9G5ZnRHTfqTLs1OYslxDw0j+p4AQyQNlS8Ty98sa82H/VsTFRHHda7O558P5/LInI+jSRI5KASGSh86pXYHP72zHwPPqMm7BT3R+ZjKfzN+gS2IlX1JAiOSx4rHR3NvlTD67sy3VTy3JXR/Mp9/rc0jbpvmwJX9RQIgEpN7pZfjotjb85ZIGzFnzC12encKr037kkOackHxCASESoOgo44bkWky4
pwPn1DqVxz5bwuUvTWfpRg1UIMFTQIjkA1XKleC1fi0Y3KcZadv2cckL03hy/A/sP6g5JyQ4CgiRfMLM6N70DL6+pwOXNqvCkImruOj5qaSs3hp0aVJEKSBE8pnypeJ4qmdT3rnpHDIPO72Hp/DQx9+zY6/mnJC8pYAQyafaJlRk/KD23NK+Nh/OWc95T09iVOp6DusktuQRBYRIPlYiLpqHLqrPpwPbUrNiKe4f/T09h81k8U8a10kiTwEhUgA0PKMso25pzZNXNmHNlj1c8sI0/vLJInbsU7eTRI4CQqSAiIoyeiZV49v7OtK3VQ3eTllLp6cnMXpumrqdJCIUECIFTNkSsfy/Ho34dGDoTuz7Ri1Qt5NEREQDwsy6mtkyM1tpZg8eZZteZrbEzBab2XtZ2q83sxXhx/WRrFOkIGp4RllG39rmN91Ofx23WN1OkmssUoOEmVk0sBw4H0gD5gB93H1Jlm0SgJHAee6+zcwquftmMzsVSAWSAAfmAs3dfdvRXi8pKclTU1Mj8l5E8rsdew/y9IRlvJOyllNLxfHQhfW5/OwqmFnQpUk+Z2Zz3T0pu3WRPIJoCax099XungF8APQ4YpubgSG/fvG7++Zw+wXABHf/JbxuAtA1grWKFGhlS8byaI9GjBsQ6na6d9QCeg6dyZKfNGSHnLhIBkQVYH2W5bRwW1aJQKKZTTezFDPrehz7isgRGlUJdTs9cWUTVm/ZQ7cXpqrbSU5Y0CepY4AEoCPQB3jFzMrldGcz629mqWaWmp6uGbpEIHS1U6+kaky8tyPXnFODt2auodPTk/hobprmnZDjEsmA2ABUy7JcNdyWVRowzt0PuvuPhM5ZJORwX9x9uLsnuXtSfHx8rhYvUtCVLRnLY5eGup2qhbudeg2bqZFiJcciGRBzgAQzq2VmcUBvYNwR24wldPSAmVUk1OW0GhgPdDGz8mZWHugSbhOR49SoSlk+urUNT1zRhFXpe+gWvtpp5351O8mxRSwg3D0TGEDoi30pMNLdF5vZo2bWPbzZeGCrmS0BJgL3u/tWd/8FeIxQyMwBHg23icgJiIoyerWoxrf3dqBPy2q8OXMN5z01mY/nqdtJji5il7nmNV3mKpJzC9N28Mgni5i/fjstapbn0R6NqF+5TNBlSQCCusxVRPKpxlXL8vFtbfjXFY1ZuXk33V6Yxv/7VN1O8lsKCJEiKirKuKpFdSbe15E+LavxxoxQt9OY79TtJCEKCJEirlzJOB6/tDGf3JFMlfIluPvDBVw1LIUfftbVTkWdAkJEAGhStRxjbmvDPy9vzIrNu7h48DQe/XSJup2KMAWEiPxHVJTRu2V1vr23I71bVOP1GT/S6enJjP1ug7qdiiAFhIj8j/Kl4vjbZaFupzPKFmfQh/O5ari6nYoaBYSIHFWTquUYc3sy/7i8MSs2hbqd/u+TRfyyJyPo0iQPKCBE5Jiioow+4W6nPi2r8U7KWjo+OZERU1eTkXk46PIkghQQIpIj5UuFrnb6clB7mlUvz+P/XkqXZyfz5aKfdX6ikFJAiMhxSTytNG/e2JI3bmhBbHQUt74zlz6vpLBog6Y8LWwUECJyQjqeWYkv7mrHYz0asnzTbi55cRr3jVrApp37gy5NcokCQkROWEx0FH1b12TifR25uV1tPpm/gXOfmsTgb1awL+NQ0OXJSVJAiMhJK1silj9dVJ+v7+lAh8R4npmwnPOensSY79I4fFjnJwoqBYSI5JoaFUrx8rXN+bB/KyqeUoy7P1zAZS/PIHWNRusviBQQIpLrzqldgU/uSObpnk35ecc+rhw6kzvem8f6X/YGXZocBwWEiEREVJRxRfOqTLyvI3d1SuCbpZvo9Mxk/vnFD+zS+E4FggJCRCKqZFwMd5+fyMT7OtKtSWWGTl5Fxycn8e6stWQe0o12+ZkCQkTyROWyJXim11mMG5BM7fhSPDxmERcPnsbUFelBlyZHoYAQkTzVpGo5Rt7SmpevOZu9
BzPp++psbnh9Nis37wq6NDlCRAPCzLqa2TIzW2lmD2azvp+ZpZvZ/PDjD1nWPWFmi81sqZkNNjOLZK0iknfMjAsbV+brezrw0IX1SF2zjQuem8pfNBBgvhKxgDCzaGAIcCHQAOhjZg2y2fRDdz8r/BgR3rcNkAw0ARoBLYAOkapVRIJRLCaaWzrUYdL9oYEA39ZAgPlKJI8gWgIr3X21u2cAHwA9crivA8WBOKAYEAtsikiVIhK4CqcU+89AgGdlGQhw/GINBBikSAZEFWB9luW0cNuRrjCz781stJlVA3D3mcBEYGP4Md7dl0awVhHJBxJPK81bN7bk9RtaEBMdxS1vayDAIAV9kvpToKa7NwEmAG8CmFldoD5QlVConGdm7Y7c2cz6m1mqmaWmp+tKCJHC4twzK/FleCDAZT/v4pIXp3H/qAVs1kCAeSqSAbEBqJZluWq47T/cfau7HwgvjgCah3++DEhx993uvhv4Amh95Au4+3B3T3L3pPj4+Fx/AyISnF8HApx0/7nc3K42Y+dvoKMGAsxTkQyIOUCCmdUyszigNzAu6wZmVjnLYnfg126kdUAHM4sxs1hCJ6jVxSRSBGUdCLB9wn8HAhyZul432kVYxALC3TOBAcB4Ql/uI919sZk9ambdw5vdGb6UdQFwJ9Av3D4aWAUsBBYAC9z900jVKiL5X40KpRjaNzQQYKXSxXhg9Pd0fX6qTmRHkBWWDzYpKclTU1ODLkNE8oC78+Win3nyq2WsTt9Ds+rl+GPXerSqXSHo0gocM5vr7knZrQv6JLWIyHH79Ua7rwa155+XN2bj9v30Hp7C9a/N1hVPuUhHECJS4O0/eIg3Z6zhpUmr2LHvIN2bnsG9XRKpUaFU0KXle8c6glBAiEihsWPfQYZNXsVr038k85DTp2V1BnaqS6XSxYMuLd9SQIhIkbJ5534Gf7uCD2avJzY6ipva1qJ/h9qUKR4bdGn5jgJCRIqkNVv28PSE5Xy64CfKlYzljo516du6BsVjo4MuLd9QQIhIkbZoww6eGL+MKcvTqVy2OHd3TuTys6sQE63rdHQVk4gUaY2qlOWtG1vy3s3nUKlMcR746HsueG4KXy7aqHsojkEBISJFRps6FRl7exuGXhsa1efWd+Zx6UszmLFqS8CV5U8KCBEpUsyMro1OZ/yg9jxxRRM279zP1a/M4jrdQ/E/dA5CRIq0/QcP8fbMtQyZtJLtew9ySdMzuPf8RGpWLBr3UOgktYjI79i5/yDDJ6/m1Wk/cvDQYa5qUY27OiVQqUzhvodCASEikkObd+7nhW9X8v7sdcREGzcm1+KWDnUoW6Jw3kOhgBAROU5rt+7h6a+WM27BT5QtEcvtHetwfZuahe4eCgWEiMgJWvzTDp74chmTl6dzepniDOqcwJXNqxaaeyh0H4SIyAlqeEZZ3ryxJe/f3IrTyxbnwY8X0uW5KXyxsPDfQ6GAEBHJgdZ1KjDm9jYM69ucKDNue3cePYZMZ+IPmwttUCggRERyyMy4oGH4Hoorm7B1dwY3vDGHy16awaRlhS8odA5CROQEZWQe5qN5abz47Uo2bN9Hs+rluLtzIu0SKmJmQZeXIzpJLSISQRmZhxk1dz1Dvl3JTzv207xGeQZ1TqBt3fwfFAoIEZE8cCDzEKNS0xgycSUbd+wnqUZ5BnVOJLluhXwbFIFdxWRmXc1smZmtNLMHs1nfz8zSzWx++PGHLOuqm9lXZrbUzJaYWc1I1ioicrKKxURzbasaTLq/I4/1aEjatn1c++oseg2byYyVWwrcOYqIHUGYWTSwHDgfSAPmAH3cfUmWbfoBSe4+IJv9JwF/c/cJZnYKcNjd9x7t9XQEISL5zf6DhxiZup4hE1eyaecBWtY6lUGdE2hTp2LQpf1HUEcQLYGV7r7a3TOAD4AeOdnRzBoAMe4+AcDddx8rHERE8qPisdFc17omk+8/l79e0oA1W/Zw9SuzuGrYTFJWbw26vN8VyYCo
AqzPspwWbjvSFWb2vZmNNrNq4bZEYLuZfWxm35nZk+Ejkt8ws/5mlmpmqenp6bn/DkREckHx2Gj6JddiygPn8pdLGrB6yx56D0+hz/AUZuXjoAj6PohPgZru3gSYALwZbo8B2gH3AS2A2kC/I3d29+HunuTuSfHx8XlTsYjICSoeG80NybWY+sC5PNKtASs27+aq4Slc/UoKc9b8EnR5/yOSAbEBqJZluWq47T/cfau7HwgvjgCah39OA+aHu6cygbHA2RGsVUQkzxSPjeamtqGg+PPF9Vm+aTc9h87k2hGzSM1HQRHJgJgDJJhZLTOLA3oD47JuYGaVsyx2B5Zm2becmf16WHAesAQRkUKkRFw0f2hXm6kPnMvDF9Xnh593cuXQmfR9dRZz124LurzIBUT4L/8BwHhCX/wj3X2xmT1qZt3Dm91pZovNbAFwJ+FuJHc/RKh76RszWwgY8EqkahURCVKJuGhubl+bKQ+cy58uqseSn3ZyxcszuO612cxbF1xQ6EY5EZF8Zm9GJm/NXMvwKav5ZU8GHRLjufv8RM6qVi7XX0t3UouIFEB7DvwaFKvYtvcg554Zz6DOiTTNxaBQQIiIFGC7D2Ty5ow1vDJ1Ndv3HuS8epUY1DmBJlXLnfRzKyBERAqBX4Ni+JTV7Nh3kM71K3FXp0QaVy17ws+pgBARKUR27T/IG9NDRxQ792dycePKvHh1sxMaEPBYARFz0pWKiEieKl08loGdErg+uSZvTF/DgcxDERktVgEhIlJAlSkey52dEiL2/EEPtSEiIvmUAkJERLKlgBARkWwpIEREJFsKCBERyZYCQkREsqWAEBGRbCkgREQkW4VmqA0zSwfWnsRTVAS25FI5BZ0+i9/S5/Fb+jz+qzB8FjXcPds5mwtNQJwsM0s92ngkRY0+i9/S5/Fb+jz+q7B/FupiEhGRbCkgREQkWwqI/xoedAH5iD6L39Ln8Vv6PP6rUH8WOgchIiLZ0hGEiIhkq8gHhJl1NbNlZrbSzB4Mup4gmVk1M5toZkvMbLGZ3RV0TUEzs2gz+87MPgu6lqCZWTkzG21mP5jZUjNrHXRNQTKzu8O/J4vM7H0zKx50TbmtSAeEmUUDQ4ALgQZAHzNrEGxVgcoE7nX3BkAr4I4i/nkA3AUsDbqIfOJ54Et3rwc0pQh/LmZWBbgTSHL3RkA00DvYqnJfkQ4IoCWw0t1Xu3sG8AHQI+CaAuPuG919XvjnXYS+AKoEW1VwzKwqcDEwIuhagmZmZYH2wKsA7p7h7tsDLSp4MUAJM4sBSgI/BVxPrivqAVEFWJ9lOY0i/IWYlZnVBJoBswIuJUjPAQ8AhwOuIz+oBaQDr4e73EaYWamgiwqKu28AngLWARuBHe7+VbBV5b6iHhCSDTM7BfgIGOTuO4OuJwhm1g3Y7O5zg64ln4gBzgZedvdmwB6gyJ6zM7PyhHobagFnAKXM7Npgq8p9RT0gNgDVsixXDbcVWWYWSygc3nX3j4OuJ0DJQHczW0Oo6/E8M3sn2JIClQakufuvR5SjCQVGUdUZ+NHd0939IPAx0CbgmnJdUQ+IOUCCmdUyszhCJ5nGBVxTYMzMCPUxL3X3Z4KuJ0ju/pC7V3X3moT+XXzr7oXuL8SccvefgfVmdma4qROwJMCSgrYOaGVmJcO/N50ohCftY4IuIEjunmlmA4DxhK5CeM3dFwdcVpCSgb7AQjObH277k7t/HlxJko8MBN4N/zG1Grgh4HoC4+6zzGw0MI/Q1X/fUQjvqtad1CIikq2i3sUkIiJHoYAQEZFsKSBERCRbCggREcmWAkJERLKlgBDJB8yso0aMlfxGASEiItlSQIgcBzO71sxmm9l8MxsWni9it5k9G54b4Bsziw9ve5aZpZjZ92Y2Jjx+D2ZW18y+NrMFZjbPzOqEn/6ULPMtvBu+Q1ckMAoIkRwys/rAVUCyu58FHAKuAUoBqe7eEJgM/CW8y1vAH929CbAwS/u7wBB3b0po
/J6N4fZmwCBCc5PUJnRnu0hgivRQGyLHqRPQHJgT/uO+BLCZ0HDgH4a3eQf4ODx/Qjl3nxxufxMYZWalgSruPgbA3fcDhJ9vtrunhZfnAzWBaRF/VyJHoYAQyTkD3nT3h37TaPbIEdud6Pg1B7L8fAj9fkrA1MUkknPfAFeaWSUAMzvVzGoQ+j26MrzN1cA0d98BbDOzduH2vsDk8Ex9aWZ2afg5iplZybx8EyI5pb9QRHLI3ZeY2Z+Br8wsCjgI3EFo8pyW4XWbCZ2nALgeGBoOgKyjn/YFhpnZo+Hn6JmHb0MkxzSaq8hJMrPd7n5K0HWI5DZ1MYmISLZ0BCEiItnSEYSIiGRLASEiItlSQIiISLYUECIiki0FhIiIZEsBISIi2fr/Ex7TM/4GhwIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# 1. Load and preprocess the data\n",
    "x1, _ = load_dataset(return_label=False)\n",
    "x2, y = load_dataset(return_label=True)\n",
    "\n",
    "x1 = transform(x1)\n",
    "x2 = transform(x2)\n",
    "\n",
    "# 2. Train the model\n",
    "W = jnp.zeros((30,))\n",
    "b = 0.0\n",
    "epochs = 10\n",
    "learning_rate = 1e-2\n",
    "\n",
    "losses, W, b = fit(W, b, x1, x2, y, epochs=epochs, learning_rate=learning_rate)\n",
    "\n",
    "# 3. Visualize metrics\n",
    "plot_losses(losses)\n",
    "\n",
    "# 4. Validate the model\n",
    "validate_model(W, b, jnp.concatenate([x1, x2], axis=1), y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3c9f6881-ed3e-487f-9eb2-c7c0dfa253c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fe0ded7-d861-4506-9b87-4342629a8c5e",
   "metadata": {},
   "source": [
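    "On top of the physical machine, we create three logical devices:\n",
    "- alice, bob: PYU devices, each responsible for its party's local plaintext computation\n",
    "- ppu: a PPU device composed of alice and bob, responsible for the two parties' secret (ciphertext) computation"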
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "14c4f276-7eb0-4dd9-a39b-06b588bcc5bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "ppu = sf.PPU(sf.utils.testing.cluster_def(['alice', 'bob']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57742bf6-7eb5-490a-a2a7-3b5ac37536f0",
   "metadata": {},
   "source": [
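    "Next, we schedule `load_dataset` to run on the logical devices alice and bob. In SecretFlow, `dev(fn)` dispatches the function `fn` to the specified logical device `dev`, which can be a PYU, PPU, TEE, or other logical device. Calling the function returns a device object (DeviceObject), which you can think of as a pointer to data that lives on a remote device."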
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "149864cc-ea35-4f3e-bb7a-2247ef835ea8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<secretflow.device.device.pyu.PYUObject at 0x7fa0a473aeb0>,\n",
       " <secretflow.device.device.pyu.PYUObject at 0x7fa2211a3f70>,\n",
       " <secretflow.device.device.pyu.PYUObject at 0x7fa2208f5ca0>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1, _ = alice(load_dataset)(return_label=False)\n",
    "x2, y = bob(load_dataset)(return_label=True)\n",
    "\n",
    "x1, x2, y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b3893fc-5f6f-4508-a93b-1fee64b6806d",
   "metadata": {},
   "source": [
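    "Then, as in the plaintext version, we standardize the data to improve model quality. Each party runs `transform` locally on its own PYU device, so its raw features stay with their owner."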
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "002e2e86-bacf-4ff0-b3d1-891e62e4ae2c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x1 = alice(transform)(x1)\n",
    "x2 = bob(transform)(x2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "019e3f74-cfa8-458b-b3ca-6d3431d0e64d",
   "metadata": {},
   "source": [
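    "Before training, we initialize the logistic regression parameters `W` and `b` to zero and transfer them, together with `x1`, `x2`, and `y`, to the PPU. This is necessary because each device stores data in its own way, and a function running on a device only accepts DeviceObjects of the matching type as input. SecretFlow provides two ways to transfer data between devices:\n",
    "- `secretflow.to`: transfers a Python object or a device object to the specified device\n",
    "- `DeviceObject.to`: transfers a device object to the specified device"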
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a899ee65-a963-4f75-b5bf-e97ed885b52f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=63880)\u001b[0m [2022-03-08 20:06:07.984] [info] [context.cc:58] connecting to mesh, id=root, self=0\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m [2022-03-08 20:06:07.964] [info] [context.cc:58] connecting to mesh, id=root, self=1\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m [2022-03-08 20:06:07.980] [info] [context.cc:83] try_connect to rank 0 not succeed, sleep_for 1000ms and retry.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m I0308 20:06:07.964183 63881 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=21397.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m I0308 20:06:07.964281 63881 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:21397 in web browser.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63880)\u001b[0m I0308 20:06:07.984158 63880 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=56300.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63880)\u001b[0m I0308 20:06:07.984239 63880 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:56300 in web browser.\n"
     ]
    }
   ],
   "source": [
    "device = ppu\n",
    "\n",
    "W = jnp.zeros((30,))\n",
    "b = 0.0\n",
    "\n",
    "W_, b_, x1_, x2_, y_ = sf.to(device, W), sf.to(device, b), x1.to(device), x2.to(device), y.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71fabe33-8a83-4cd7-97ef-66743ce4d492",
   "metadata": {},
   "source": [
    "With everything in place, we can start training on the PPU. After training, the losses and model parameters are PPUObjects stored on the PPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d2b3a1a2-f12d-4fe7-bea5-5a6e4faf01ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m I0308 20:06:08.065062 64226 external/com_github_brpc_brpc/src/brpc/socket.cpp:2202] Checking Socket{id=0 addr=127.0.0.1:56300} (0x7f3caef0bb40)\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m I0308 20:06:08.065236 64226 external/com_github_brpc_brpc/src/brpc/socket.cpp:2262] Revived Socket{id=0 addr=127.0.0.1:56300} (0x7f3caef0bb40) (Connectable)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(<secretflow.device.device.ppu.PPUObject at 0x7fa22078e6d0>,\n",
       " <secretflow.device.device.ppu.PPUObject at 0x7fa22078e7f0>,\n",
       " <secretflow.device.device.ppu.PPUObject at 0x7fa22078e820>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "losses, W_, b_ = device(fit, static_argnames=['epochs'])(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)\n",
    "\n",
    "losses, W_, b_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea3a5d7b-f16d-4330-a96f-5a4f6b0a2996",
   "metadata": {},
   "source": [
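    "### Visualizing Metrics\n",
    "\n",
    "We can inspect the loss curve on the training set to tune hyperparameters. However, `losses` is now a PPUObject, so we first need to convert it into a Python object. SecretFlow provides `sf.reveal` to convert a DeviceObject of any type into a Python object.\n",
    "\n",
    "> Use `sf.reveal` with caution: revealing values may leak private data."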
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8a725cc8-21f4-4e02-ae77-a10f2dcd1187",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(PPURuntime pid=63880)\u001b[0m [2022-03-08 20:06:08.980] [info] [context.cc:111] connected to mesh, id=root, self=0\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m [2022-03-08 20:06:08.980] [info] [context.cc:111] connected to mesh, id=root, self=1\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63880)\u001b[0m 20:06:14 TRACE: [Profiling] PPU execution completed, input processing took 0.004173861s, execution took 5.960133054s, output processing took 4.4542e-05s, total time 5.964351457s.\n",
      "\u001b[2m\u001b[36m(PPURuntime pid=63881)\u001b[0m 20:06:14 TRACE: [Profiling] PPU execution completed, input processing took 0.004294032s, execution took 5.961247655s, output processing took 3.8232e-05s, total time 5.965579919s.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'loss')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEJCAYAAACOr7BbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAApMklEQVR4nO3deXRU9f3/8ec7G2FfgyKLgAQFEUQDKGETFMENRUVBFDesFURcq21tv9Xa9udSV0QBN6yKFBVQW4FaZN+CAkoQCIsQQAn7Dkl4//6YsY1pwACZ3Enyepwz5+R+7ufOvGdOMq/cz733c83dERERyS8m6AJERCQ6KSBERKRACggRESmQAkJERAqkgBARkQIpIEREpEARDQgz62Fmy80sw8weLmD9s2a2KPxYYWY78qwbYGYrw48BkaxTRET+l0XqOggziwVWABcBmcACoK+7px+h/91Aa3e/1cxqAGlACuDAQuBcd98ekWJFROR/xEXwudsCGe6+GsDMxgC9gAIDAugL/D7888XAFHffFt52CtADeO9IL1arVi1v2LBh0VQuIlJGLFy4cIu7JxW0LpIBURdYn2c5E2hXUEczOxVoBPz7KNvWPdqLNWzYkLS0tOMuVkSkLDKz7460LloOUl8PjHP33GPZyMzuMLM0M0vLysqKUGkiImVTJANiA1A/z3K9cFtBruenw0eF2tbdR7h7irunJCUVuIckIiLHKZIBsQBINrNGZpZAKAQm5u9kZmcA1YE5eZonAd3NrLqZVQe6h9tERKSYROwYhLvnmNlgQl/sscDr7r7UzB4D0tz9x7C4HhjjeU6ncvdtZvY4oZABeOzHA9YiIlI8Inaaa3FLSUlxHaQWETk2ZrbQ3VMKWhctB6lFRCTKKCBERKRAZT4gDh92/vSPZazfti/oUkREokqZD4i1W/cyZv46eg2bxbzVW4MuR0QkapT5gGicVInxg1KpViGe/q/NY8z8dUGXJCISFcp8QEAoJD66K5XzT6vFwx9+zR8+XkpO7uGgyxIRCZQCIqxq+XheH5DCbR0a8castdzy5gJ27ssOuiwRkcAoIPKIi43h0cua8+TVLZm7eitXvTyLVVl7gi5LRCQQCogC9GlTn3cHnsfO/dlcOWwW01ZoIkARKXsUEEfQpmENJgxOpV71Ctzyxnxem7mG0nLVuYhIYSggjqJe9QqMu/N8Lmp+Eo9/ks6vPljCwZxjmpFcRKTEUkD8jIrl4hh+w7kM6dqEsWmZ9B81jy17DgZdlohIxCkgCiEmxriv++m82Lc1SzJ30uulWaRv3BV0WSIiEaWAOAaXtzqFcXe2J/ewc80rs/nsm++DLklEJGIUEMforHpVmTg4leSTKnPn3xby4ucrdfBaREolBcRxqF0lkffvOI+rWtflmSkrGDJmEfsP6eC1iJQuEbujXGmXGB/LX/u0oulJlXly0res3bKXkTelcHLVxKBLExEpEtqDOAFmxi+7nMbIG1NYnbWHy1+ayVfrtgddlohIkVBAFIELm5/Eh3elkhgfw3Uj5jL+qw1BlyQicsIiGhBm1sPMlptZhpk9fIQ+fcws3cyWmtm7edqfDLctM7MXzMwiWeuJOv3kykwY1IHW9asx9P1F/L/PvuXwYR28FpGSK2IBYWaxwDCgJ9Ac6GtmzfP1SQYeAVLd/UxgaLi9PZAKtARaAG2AzpGqtajUqJjA27e1o1+7Bgz/YhV3vJ3GnoM5QZclInJcIrkH0RbIcPfV7n4IGAP0ytdnIDDM3bcDuPvmcLsDiUACUA6IB36IYK1FJiEuhieubMFjvc5k6vIser88i3VbdTtTESl5IhkQdYH1eZYzw215NQWamtksM5trZj0A3H0OMBXYFH5McvdlEay1SJkZN53fkNG3tuWHXQfpNWwmc3U7UxEpYYI+SB0HJANdgL7ASDOrZmZNgGZAPUKh0tXMOubf2MzuMLM0M0vLyoq+KblTm9RiwqBU
alRMoP+oebw7T7czFZGSI5IBsQGon2e5Xrgtr0xgortnu/saYAWhwLgKmOvue9x9D/BP4Pz8L+DuI9w9xd1TkpKSIvImTlTDWhX5aFAqHZJr8euPvub3E77R7UxFpESIZEAsAJLNrJGZJQDXAxPz9RlPaO8BM6tFaMhpNbAO6GxmcWYWT+gAdYkZYsqvSmI8rw1owx2dGvPWnO8Y8MZ8duw7FHRZIiJHFbGAcPccYDAwidCX+1h3X2pmj5nZFeFuk4CtZpZO6JjDg+6+FRgHrAK+BhYDi93940jVWhxiY4xfX9KMp69txYI127ly2CwyNut2piISvay0TDSXkpLiaWlpQZdRKAu/28Yv3l7IwezDvNCvNRecXjvokkSkjDKzhe6eUtC6oA9Sl0nnnlqDCYM7UL9GBW57cwGjZqzWjLAiEnUUEAGpW6084355PhefeTJ//HQZD47T7UxFJLooIAJUISGOYf3O4Z5uyYxbmEm/kfPYvOtA0GWJiAAKiMDFxBj3XtSUYf3OIX3jLi59cSbzdFGdiEQBBUSUuLRlHcYPSqVyuTj6jZrHyOk6LiEiwVJARJHTT67MhMGpdG9+Ek/8Yxl3vfMluw9kB12WiJRRCogoUzkxnpdvOIffXNKMyek/0OulWSz/fnfQZYlIGaSAiEJmxsBOjXn39nbsPpjDlcNm6SZEIlLsFBBRrF3jmnx6dwfOqluVoe8v4ncTvtGpsCJSbBQQUa52lUTeGdiOOzo1ZvSc77ju1bls3LE/6LJEpAxQQJQA8bEx/PqSZrzS/xwyNu/h0hdmMGNl9E1vLiKliwKiBOnRog4TB6dSu3IiN70+n5f+vVL3vRaRiFFAlDCNkyrx0aD29Gp1Ck9PXsHto9PYuU+nwopI0VNAlEAVEuJ49rqzefzKFsxYmcWlL87gmw07gy5LREoZBUQJZWbceN6pjP3F+Rw+7PQePpv3F+iWpiJSdBQQJVzrBtX5ZEhH2jWqwa8++JqHxi3mQLZOhRWRE6eAKAVqVEzgzVvaMqRbMmPTMun98mzWbd0XdFkiUsIpIEqJ2Bjjvoua8sbNbdiwYz+XvTiDf6X/EHRZIlKCKSBKmQvOqM0nd3egQc0K3D46jacmfUuuToUVkeMQ0YAwsx5mttzMMszs4SP06WNm6Wa21MzezdPewMwmm9my8PqGkay1NKlfowLj7mxP37b1GTZ1FTe9Po8tew4GXZaIlDARCwgziwWGAT2B5kBfM2uer08y8AiQ6u5nAkPzrB4NPOXuzYC2wOZI1VoaJcbH8ufeLXnympakrd3OZS/MZOF324MuS0RKkEjuQbQFMtx9tbsfAsYAvfL1GQgMc/ftAO6+GSAcJHHuPiXcvsfdddT1OPRJqc+Hd7UnIS6G616dw5uz1uhGRCJSKJEMiLrA+jzLmeG2vJoCTc1slpnNNbMeedp3mNmHZvaVmT0V3iOR43DmKVX5+O4OdDk9if/7OJ17xixi78GcoMsSkSgX9EHqOCAZ6AL0BUaaWbVwe0fgAaAN0Bi4Of/GZnaHmaWZWVpWliavO5qq5eMZcWMKD158Op8s2ciVw2aRsXlP0GWJSBSLZEBsAOrnWa4XbssrE5jo7tnuvgZYQSgwMoFF4eGpHGA8cE7+F3D3Ee6e4u4pSUlJkXgPpUpMjDHogia8fVs7tu09RK+XZvLpkk1BlyUiUSqSAbEASDazRmaWAFwPTMzXZzyhvQfMrBahoaXV4W2rmdmP3/pdgfQI1lqmpDapxSdDOnD6yZUZ9O6XPP5JOtm5h4MuS0SiTMQCIvyf/2BgErAMGOvuS83sMTO7ItxtErDVzNKBqcCD7r7V3XMJDS99bmZfAwaMjFStZVGdquUZc8f53Ny+Ia/NXEPfEXP5YdeBoMsSkShipeWMlpSUFE9LSwu6jBJp4uKNPPzBEiokxPFSv9ac17hm0CWJSDExs4XunlLQuqAPUksUuKLVKUwYlErV8nHcMGoer0xbpRsRiYgC
QkKST6rMhMEd6NHiZP7yz2+5+c0FZO3W1dciZZkCQv6jUrk4XurbmieuasG81Vvp+fx0pq3Q6cMiZZUCQn7CzLih3al8fHcHalYsx4DX5/PEp+kcytFZTiJljQJCCtT0pMpMGJzKTeefysgZa7h6+GzWbNkbdFkiUowUEHJEifGxPNarBSNuPJf12/dx6QszGLcwU3M5iZQRCgj5Wd3PPJl/3tORlvWq8sDfFzP0/UXsPpAddFkiEmEKCCmUOlXL887t5/FA96Z8smQTl7wwg6/WafpwkdJMASGFFhtjDO6azNhfnMfhw3DtK3N4+YsMXTMhUkopIOSYnXtqDf5xT0cubnEyT362nP6vzdM0HSKlkAJCjkvV8vG81Lc1T17dkq/W7aDHc9P5fNkPQZclIkVIASHHzczo06Y+H9/dgTpVy3PbW2n838SlHMjODbo0ESkCCgg5YU1qV+KjQe25NbURb85eG74Z0e6gyxKRE6SAkCJRLi6W313enNdvTmHz7oNc9uJM3pu/TtdMiJRgCggpUl3POInP7ulIyqk1eOTDrxn07pfs3KdrJkRKIgWEFLnaVRIZfWtbHu55BpOX/sAlL8wgbe22oMsSkWOkgJCIiIkx7ux8GuN+2Z64WKPPq3N4/l8rydU1EyIlhgJCIurs+tX45O4O9Dq7Ls/+awV9R85l4479QZclIoWggJCIq5wYz7PXnc1f+7Ri6Yad9Hx+Bp99synoskTkZyggpNj0Pqcenw7pyKk1K3Dn377k1x99zf5DumZCJFpFNCDMrIeZLTezDDN7+Ah9+phZupktNbN3862rYmaZZvZSJOuU4tOwVkXG3dmeX3RuzLvz1nHFSzP59vtdQZclIgWIWECYWSwwDOgJNAf6mlnzfH2SgUeAVHc/Exia72keB6ZHqkYJRkJcDI/0bMboW9uyfV82V7w0i9Fz1uqaCZEoE8k9iLZAhruvdvdDwBigV74+A4Fh7r4dwN03/7jCzM4FTgImR7BGCVCnpkl8NrQj7U+rye8mLGXg6IVs33so6LJEJCySAVEXWJ9nOTPclldToKmZzTKzuWbWA8DMYoBngAeO9gJmdoeZpZlZWlZWVhGWLsWlVqVyvD6gDY9e1pxpKzbT8/kZzFm1NeiyRITgD1LHAclAF6AvMNLMqgF3Af9w98yjbezuI9w9xd1TkpKSIl2rREhMjHFbh0Z8dFcqFRJi6TdqLk9PWk527uGgSxMp0yIZEBuA+nmW64Xb8soEJrp7truvAVYQCozzgcFmthZ4GrjJzP4SwVolCrSoW5WP7+7AtefW46WpGVz7yhxWZ+0JuiyRMiuSAbEASDazRmaWAFwPTMzXZzyhvQfMrBahIafV7n6Duzdw94aEhplGu3uBZ0FJ6VKxXBxPXtOKF/u2Zs2WvVzywgzemLVGd60TCUDEAsLdc4DBwCRgGTDW3Zea2WNmdkW42yRgq5mlA1OBB91dA9DC5a1OYfK9nTi/cU3+8HE6N4yaR+b2fUGXJVKmWGk5tTAlJcXT0tKCLkOKmLszNm09j32cjpnx6GXN6JNSHzMLujSRUsHMFrp7SkHrgj5ILXJUZsZ1bRrw2dBOtKhbhV998DW3vrlA98AWKQYKCCkR6teowLu3n8fvL2/OnNVb6f7sdCYu3qiL60QiSAEhJUZMjHFLaiP+MaQjjZMqMuS9rxj87lds08V1IhGhgJASp3FSJf7+i/N5qMfpTE7/nu7PTmNK+g9BlyVS6iggpESKi43hri5NmDi4A0mVExk4Oo37xy5m537d3lSkqCggpERrVqcKEwalcnfXJoxftIEez01nxkpNuyJSFBQQUuIlxMVwf/fT+eCX7amQEMuNr83nt+O/Zu/BnKBLEynRFBBSapxdvxqfDunI7R0a8c68dVzywgwWrN0WdFkiJZYCQkqVxPhYfntZc8YMPI/D7vR5dQ5PfJrOgWzduU7kWCkgpFRq17gmn93TiX5tGzByxhoue3EmSzJ3
BF2WSIlSqIAws3vCt/80M3vNzL40s+6RLk7kRFQsF8cTV53FW7e2Zc+BHK56eTZ/nbKCQzmaRlykMAq7B3Gru+8CugPVgRsBTb8tJULnpklMGtqJXq1O4YXPV3LVy7NY/v3uoMsSiXqFDYgfZ0a7BHjb3ZfmaROJelUrxPPX687mlf7n8v3OA1z+4kxembaKXE0jLnJEhQ2IhWY2mVBATDKzyoD206XE6dHiZCbf24muZ9TmL//8lmtfmc2aLXuDLkskKhU2IG4DHgbauPs+IB64JWJViURQzUrlGN7/HJ677mwyNu+h5/PTeWv2Wt2USCSfwgbE+cByd99hZv2B3wI7I1eWSGSZGVe2rsvkezvTrlFNfj9xKf1f002JRPIqbEAMB/aZWSvgfmAVMDpiVYkUk5OrJvLmLW34c++zWLx+Bz2em8HYBes1jbgIhQ+IHA/9xfQCXnL3YUDlyJUlUnzMjL5tQzclOvOUKjz0wRJufyuNzbopkZRxhQ2I3Wb2CKHTWz81sxhCxyFESo36NSrw3sDzePSy5szM2EL356bz8eKNQZclEpjCBsR1wEFC10N8D9QDnvq5jcysh5ktN7MMM3v4CH36mFm6mS01s3fDbWeb2Zxw2xIzu66QdYqckJgY47YOjfh0SEdOrVmRu9/7il/+bSGbd2tvQsoeK+xYq5mdBLQJL853980/0z8WWAFcBGQCC4C+7p6ep08yMBbo6u7bzay2u282s6aAu/tKMzsFWAg0c/cdR3q9lJQUT0tLK9R7ESmMnNzDvDp9Nc9/vpLEuBh+fUkzrmtTHzNdAiSlh5ktdPeUgtYVdqqNPsB84FqgDzDPzK75mc3aAhnuvtrdDwFjCB3DyGsgMMzdtwP8GDruvsLdV4Z/3ghsBpIKU6tIUYmLjWHQBU345z0dOaNOFR7+8Gv6jpyr6yakzCjsENNvCF0DMcDdbyL05f/oz2xTF1ifZzkz3JZXU6Cpmc0ys7lm1iP/k5hZWyCB0JlTIsXutKRKjBl4Hn/ufRZLN+7i4uemM2xqBtm5ulZUSrfCBkRMviGlrcew7dHEAclAF6AvMNLMqv240szqAG8Dt7j7//w1mtkdZpZmZmlZWbqLmEROTEzoTKfP7+tMtzNq89Sk5Vz+4kwWr98RdGkiEVPYL/nPzGySmd1sZjcDnwL/+JltNgD18yzXC7fllQlMdPdsd19D6JhFMoCZVQm/zm/cfW5BL+DuI9w9xd1TkpI0AiWRV7tKIsP7n8urN57L9n2HuOrlWTz+STr7DunudVL6FCog3P1BYATQMvwY4e6/+pnNFgDJZtbIzBKA64GJ+fqMJ7T3gJnVIjTktDrc/yNgtLuPK9xbESk+F595MlPu60zftg14beYauj87nWkrtBcrpUuhz2I6ric3uwR4DogFXnf3J8zsMSDN3Sda6HSQZ4AeQC7whLuPCU/n8QawNM/T3ezui470WjqLSYIyf802HvlwCauy9nLl2afw6GXNqVmpXNBliRTK0c5iOmpAmNluoKAORug01CpFU+KJU0BIkA7m5DJs6iqGf5FBpXJxPHpZc65qXVenxErUO+7TXN29srtXKeBROZrCQSRo5eJiue+ipnw6pCMNa1XkvrGLuen1+azfpsn/pOTSPalFilDTkyoz7s72/OGKM/nyu+10f3Y6o2asJkenxEoJpIAQKWKxMcaA9g2Zcl9n2p9Wkz9+uozew2eTvnFX0KWJHBMFhEiEnFKtPKMGpPBi39Zs3LGfy1+ayf/77FsOZOcGXZpIoSggRCLIzLi81Sn8677OXH1OXYZ/sYoez01n9qotQZcm8rMUECLFoFqFBJ68phXv3N4OB/qNnMdD4xazc1920KWJHJECQqQYpTapxWf3dOLOzqfxwZcb6PbXaXy6ZJPuYCdRSQEhUszKJ8TycM8zmDAolTpVExn07pcMHJ3Gpp37gy5N5CcUECIBaVG3Kh/d1Z7fXNKMmRlbuOiv0xk9Zy2HD2tv
QqKDAkIkQHGxMQzs1JjJQzvTukE1fjdhKde+OoeVP+wOujQRBYRINGhQswKjb23LX/u0YlXWHi55YQbPTlnBwRydEivBUUCIRAkzo/c59fj8vs5celYdnv98JZe+MJO0tduCLk3KKAWESJSpWakcz13fmjduacP+Q7lc88ocHh3/DbsP6JRYKV4KCJEodcHptZl8byduTW3EO/O+o+sz0xj/1QadEivFRgEhEsUqlovjd5c3Z/ygVE6pmsjQ9xdx3Yi5LP9eB7El8hQQIiVAy3rV+OiuVP7c+yxW/rCbS16YweOfpGvYSSJKASFSQsTEGH3bNuDf93fhujb1eX3WGg07SUQpIERKmOoVE/jTVWcxQcNOEmEKCJESKu+w04rwsNMfNewkRSiiAWFmPcxsuZllmNnDR+jTx8zSzWypmb2bp32Ama0MPwZEsk6RkurHYaep93ehT0p9Xpu1hm7PTGPCIg07yYmzSP0SmVkssAK4CMgEFgB93T09T59kYCzQ1d23m1ltd99sZjWANCAFcGAhcK67bz/S66WkpHhaWlpE3otISbF4/Q4enfANSzJ30q5RDR6/sgVNT6ocdFkSxcxsobunFLQuknsQbYEMd1/t7oeAMUCvfH0GAsN+/OJ3983h9ouBKe6+LbxuCtAjgrWKlAqt6oeGnf501Vks/2E3PZ/XsJMcv0gGRF1gfZ7lzHBbXk2BpmY2y8zmmlmPY9hWRAoQG2P0a6dhJzlxQR+kjgOSgS5AX2CkmVUr7MZmdoeZpZlZWlZWVmQqFCmhqldM4M+9z+Kju1I5uWoi94xZxPUj5rJCM8VKIUUyIDYA9fMs1wu35ZUJTHT3bHdfQ+iYRXIht8XdR7h7irunJCUlFWnxIqXF2fmGnS55fgZPfJrOnoM5QZcmUS6SAbEASDazRmaWAFwPTMzXZzyhvQfMrBahIafVwCSgu5lVN7PqQPdwm4gchx+Hnf59fxeuTanHqJlr6PbMFxp2kqOKWEC4ew4wmNAX+zJgrLsvNbPHzOyKcLdJwFYzSwemAg+6+1Z33wY8TihkFgCPhdtE5ATUqJjAn3u35MNftqd25dCwU9+RGnaSgkXsNNfiptNcRY5N7mFnzIJ1PPnZcvYezOHWDo0Y0i2ZSuXigi5NilFQp7mKSBSLjTFuaHcqUx/owjXn1mPE9NV0e+YLJi7eqGEnARQQImVejYoJ/OXqlnx0V2jYach7X9Fv5DzdF1sUECIS0rpBdcYPSuWPV7YgfdMuej4/gz/9Y5nOdirDFBAi8h+xMUb/8zTsJCEKCBH5Hz8OO314V3uSKpfTsFMZpYAQkSM6p0F1JgzqwONXtmDpxp30fH4Gf/h4KTv2HQq6NCkGCggROarYGOPG8LDTtSn1eGv2Wro8/QVvzFpDdu7hoMuTCFJAiEih1KxUjj/3bsmnQzrS4pSq/OHjdC5+djr/Sv9BxydKKQWEiByTZnWq8PZtbXltQAoY3D46jf6vzWPZpl1BlyZFTAEhIsfMzOjW7CQmDe3E/13enKUbd3HpCzN4+IMlbN59IOjypIgoIETkuMXHxnBzaiOmPXABN7dvxLiFmVzw1BcMm5rBgezcoMuTE6SAEJETVrVCPL+7vDmT7+1E+ya1eGrScro9M03XT5RwCggRKTKNkyox8qYU3r29HVXKxzPkva/oPXw2X6474u3kJYopIESkyLVvUotP7u7Ak1e3JHP7fnq/PJsh733Fhh37gy5NjoECQkQiIjbG6NOmPlMf6MLgC5owaen3dH36C56a9K3mdyohFBAiElGVysXxwMWn8+8HutCjxckMm7qKC57+gvcXrCP3sI5PRDMFhIgUi7rVyvP89a356K721K9enl998DWXvTiT2Rlbgi5NjkABISLFqnWD6nzwy/a82Lc1u/Zn02/UPG5/K43VWXuCLk3yUUCISLEzMy5vdQqf39+Zh3qcztzVW+n+7HRNBBhlIhoQZtbDzJabWYaZPVzA+pvNLMvM
FoUft+dZ96SZLTWzZWb2gplZJGsVkeKXGB/LXV2a/GQiwM5PaSLAaBGxgDCzWGAY0BNoDvQ1s+YFdH3f3c8OP0aFt20PpAItgRZAG6BzpGoVkWAlVf7vRIBn1dVEgNEiknsQbYEMd1/t7oeAMUCvQm7rQCKQAJQD4oEfIlKliESNI00EmL5REwEGIZIBURdYn2c5M9yW39VmtsTMxplZfQB3nwNMBTaFH5PcfVkEaxWRKFHgRIAvaiLAIAR9kPpjoKG7twSmAG8BmFkToBlQj1CodDWzjvk3NrM7zCzNzNKysrKKsWwRibQfJwL84oEu3KKJAAMRyYDYANTPs1wv3PYf7r7V3Q+GF0cB54Z/vgqY6+573H0P8E/g/Pwv4O4j3D3F3VOSkpKK/A2ISPCqVUgocCLADxZm6kK7CItkQCwAks2skZklANcDE/N2MLM6eRavAH4cRloHdDazODOLJ3SAWkNMImXYfyYCHNiO6hXjuf/vi+n5/HSm6EB2xEQsINw9BxgMTCL05T7W3Zea2WNmdkW425DwqayLgSHAzeH2ccAq4GtgMbDY3T+OVK0iUnK0P60WEwd14KV+rcnOdQaOTuOaV+Ywf822oEsrday0JG9KSoqnpaUFXYaIFKPs3MP8PS2T5/61gs27D9L1jNo8ePHpNKtTJejSSgwzW+juKQWuU0CISEm3/1Aub85ey/AvMth9MIcrz67LvRc2pUHNCkGXFvUUECJSJuzcl83waat4Y9YaDrvTr20DBndNJqlyuaBLi1oKCBEpU77feYDnP1/J2LT1lIuL4fYOjRjYqTGVE+ODLi3qKCBEpExanbWHZ6as4NMlm6heIZ5BFzSh/3mnkhgfG3RpUUMBISJl2pLMHTw1aTkzVm6hbrXyDL0wmd7n1CM2RnOAHi0ggr6SWkQk4lrWq8bbt7XjndvbUbNSAg+OW0KP56Yzaen3uobiKBQQIlJmpDapxYRBqbx8wznkHnZ+8fZCeg+fzdzVW4MuLSopIESkTDEzLjmrDpPv7cRfep/Fph0HuH7EXAa8Pp+lG3cGXV5U0TEIESnTDmTn8tbstbz8xSp27s/milancH/3ppxas2LQpRULHaQWEfkZO/dn8+q0Vbw+aw05uU7ftg24u1sTaldODLq0iFJAiIgU0uZdB3jh3ysZM3898bEx3NahEXd0bkyVUnoNhQJCROQYrd2yl2emrODjxRupViGeQV2acOP5pe8aCgWEiMhx+mbDTp6ctJzpK7KoUzWRoRcmc/U59YiLLR3n+Og6CBGR49SiblVG39qWdwe2o3aVRH71wddc/Nx0PvtmU6m/hkIBISJSCO1Pq8X4u9rzSv/QjS/v/NuXXDlsFlOXby61QaGAEBEpJDOjR4uTmTS0E09e3ZItew5xyxsL6D18NtNWZJW6oNAxCBGR43Qo5zDjFmYybGoGG3bs55wG1Rh6YVM6JtfCrGTM86SD1CIiEXQwJ5e/p2Xy8tQMNu48wLmnVufeC5uS2qRm1AeFAkJEpBgczMllbDgoNu08QJuG1Rl6YVPanxa9QRHYWUxm1sPMlptZhpk9XMD6m80sy8wWhR+351nXwMwmm9kyM0s3s4aRrFVE5ESVi4vlxvNO5YsHu/BYrzNZv20/N4yax3WvzmX2qi1Bl3fMIrYHYWaxwArgIiATWAD0dff0PH1uBlLcfXAB238BPOHuU8ysEnDY3fcd6fW0ByEi0eZAdi7vL1jPy19k8MOug7RrVIOhFzbl/NNqBl3afwS1B9EWyHD31e5+CBgD9CrMhmbWHIhz9ykA7r7naOEgIhKNEuNjGdC+IdMevIDfX96cNVv20nfkXK4fMYd5JWCK8UgGRF1gfZ7lzHBbfleb2RIzG2dm9cNtTYEdZvahmX1lZk+F90hEREqcxPhYbkltxPSHLuB3lzVnVdZerhsxl74j5jJ/zbagyzuioK+D+Bho6O4tgSnAW+H2OKAj8ADQBmgM3Jx/YzO7w8zSzCwt
KyureCoWETlOifGx3NqhETMeuoBHL2vOys176PPqHG4YNZcFa6MvKCIZEBuA+nmW64Xb/sPdt7r7wfDiKODc8M+ZwKLw8FQOMB44J/8LuPsId09x95SkpKSirl9EJCIS42O5LRwUv720Gcu/3821r8yh/6h5LPwueoIikgGxAEg2s0ZmlgBcD0zM28HM6uRZvAJYlmfbamb247d+VyAdEZFSpHxCLLd3bMyMh7rym0uasWzTLq4ePocbX5vHwu+2B11e5AIi/J//YGASoS/+se6+1MweM7Mrwt2GmNlSM1sMDCE8jOTuuYSGlz43s68BA0ZGqlYRkSCVT4hlYKfGzPjVBTzS8wyWbtzF1cNnc9Pr8/lqXXBBoQvlRESizN6DObw99ztGTF/Ntr2H6HJ6EkMvbMrZ9asV+WvpSmoRkRJo78EcRs/5jhHTV7F9XzYXhIOiVREGhQJCRKQE23Mwh7dmr2XkjNXs2JdN1zNqM/TCZFrWq3bCz62AEBEpBX4MihHTV7NzfzYXNqvNPd2acla9qsf9nAoIEZFSZPeBbN6cFdqj2HUgh0vPqsNL/Vof14SARwuIuBOuVEREilXlxHju7pbMgNSGvDlrLQdzciMyW6wCQkSkhKqSGM+QbskRe/6gp9oQEZEopYAQEZECKSBERKRACggRESmQAkJERAqkgBARkQIpIEREpEAKCBERKVCpmWrDzLKA707gKWoBW4qonJJOn8VP6fP4KX0e/1UaPotT3b3AW3KWmoA4UWaWdqT5SMoafRY/pc/jp/R5/Fdp/yw0xCQiIgVSQIiISIEUEP81IugCoog+i5/S5/FT+jz+q1R/FjoGISIiBdIehIiIFKjMB4SZ9TCz5WaWYWYPB11PkMysvplNNbN0M1tqZvcEXVPQzCzWzL4ys0+CriVoZlbNzMaZ2bdmtszMzg+6piCZ2b3hv5NvzOw9M0sMuqaiVqYDwsxigWFAT6A50NfMmgdbVaBygPvdvTlwHjCojH8eAPcAy4IuIko8D3zm7mcArSjDn4uZ1QWGACnu3gKIBa4PtqqiV6YDAmgLZLj7anc/BIwBegVcU2DcfZO7fxn+eTehL4C6wVYVHDOrB1wKjAq6lqCZWVWgE/AagLsfcvcdgRYVvDigvJnFARWAjQHXU+TKekDUBdbnWc6kDH8h5mVmDYHWwLyASwnSc8BDwOGA64gGjYAs4I3wkNsoM6sYdFFBcfcNwNPAOmATsNPdJwdbVdEr6wEhBTCzSsAHwFB33xV0PUEws8uAze6+MOhaokQccA4w3N1bA3uBMnvMzsyqExptaAScAlQ0s/7BVlX0ynpAbADq51muF24rs8wsnlA4vOPuHwZdT4BSgSvMbC2hoceuZva3YEsKVCaQ6e4/7lGOIxQYZdWFwBp3z3L3bOBDoH3ANRW5sh4QC4BkM2tkZgmEDjJNDLimwJiZERpjXubufw26niC5+yPuXs/dGxL6vfi3u5e6/xALy92/B9ab2enhpm5AeoAlBW0dcJ6ZVQj/3XSjFB60jwu6gCC5e46ZDQYmEToL4XV3XxpwWUFKBW4EvjazReG2X7v7P4IrSaLI3cA74X+mVgO3BFxPYNx9npmNA74kdPbfV5TCq6p1JbWIiBSorA8xiYjIESggRESkQAoIEREpkAJCREQKpIAQEZECKSBEooCZddGMsRJtFBAiIlIgBYTIMTCz/mY238wWmdmr4ftF7DGzZ8P3BvjczJLCfc82s7lmtsTMPgrP34OZNTGzf5nZYjP70sxOCz99pTz3W3gnfIWuSGAUECKFZGbNgOuAVHc/G8gFbgAqAmnufiYwDfh9eJPRwK/cvSXwdZ72d4Bh7t6K0Pw9m8LtrYGhhO5N0pjQle0igSnTU22IHKNuwLnAgvA/9+WBzYSmA38/3OdvwIfh+ydUc/dp4fa3gL+bWWWgrrt/BODuBwDCzzff3TPDy4uAhsDMiL8rkSNQQIgUngFvufsjP2k0ezRfv+Od
v+Zgnp9z0d+nBExDTCKF9zlwjZnVBjCzGmZ2KqG/o2vCffoBM919J7DdzDqG228EpoXv1JdpZleGn6OcmVUozjchUlj6D0WkkNw93cx+C0w2sxggGxhE6OY5bcPrNhM6TgEwAHglHAB5Zz+9EXjVzB4LP8e1xfg2RApNs7mKnCAz2+PulYKuQ6SoaYhJREQKpD0IEREpkPYgRESkQAoIEREpkAJCREQKpIAQEZECKSBERKRACggRESnQ/wf46NMrGYGxFQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Reveal the training losses as plaintext so they can be plotted locally\n",
    "losses = sf.reveal(losses)\n",
    "\n",
    "plt.plot(np.arange(len(losses)), losses)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a58ff0dc-e2ac-4dfe-b562-4f3e42059c02",
   "metadata": {},
   "source": [
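    "Finally, let's check the accuracy and AUC on the training set. Because the labels are held by bob, we score the training samples on the PPU, transfer the predictions `y_pred` to bob, and compute the accuracy and AUC on bob's side."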
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9bdffb27-4a44-4ee6-ab00-30d1b906fb38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc=0.9838407060937583, acc=0.9349736571311951\n"
     ]
    }
   ],
   "source": [
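    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# Score the training samples on the PPU: concatenate both parties' features and apply the trained model\n",
    "y_pred_ = device(lambda W, b, x1, x2: predict(W, b, jnp.concatenate([x1, x2], axis=1)))(W_, b_, x1_, x2_)\n",
    "# Transfer the predictions to bob, who holds the labels\n",
    "y_pred = y_pred_.to(bob)\n",
    "\n",
    "# Compute AUC and accuracy (threshold 0.5) on bob's side\n",
    "auc = bob(roc_auc_score)(y, y_pred)\n",
    "acc = bob(lambda y_true, y_pred: jnp.mean((y_pred > 0.5) == y_true))(y, y_pred)\n",
    "\n",
    "print(f'auc={sf.reveal(auc)}, acc={sf.reveal(acc)}')"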
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7912d021",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "# Release the Ray processes that host the simulated parties and the PPU runtime\n",
    "ray.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33492f50-383d-401f-804e-68b2f94594f7",
   "metadata": {},
   "source": [
    "The steps above showed how to train a logistic regression model on the PPU. You can try applying the same approach to your own dataset."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
